{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-17T12:49:57.991109Z",
     "iopub.status.busy": "2022-12-17T12:49:57.990713Z",
     "iopub.status.idle": "2022-12-17T12:49:58.002093Z",
     "shell.execute_reply": "2022-12-17T12:49:58.000496Z",
     "shell.execute_reply.started": "2022-12-17T12:49:57.991075Z"
    }
   },
   "source": [
    "#  &#x1F4D1; **HW4 音频分类**\n",
    "- 给定音频区分出说话的人\n",
    "- 主要目标: 学会使用transformer\n",
    "- Baselines:\n",
    "  - Easy: 知道怎么使用transformer, 输出简单的可以运行的脚本.\n",
    "  - Medium: 知道transformer如何调参.\n",
    "  - <font color=darkred><b>Strong: 改变transformer的结构，使用一种transformer变体—— [conformer](https://arxiv.org/abs/2005.08100)  </font></b>\n",
    "  - <font color=darkred><b>Boss: 使用 [Self-Attention Pooling](https://arxiv.org/pdf/2008.01077v1.pdf) & [Additive Margin Softmax](https://arxiv.org/pdf/1801.05599.pdf)进一步提升模型表现. </font></b>\n",
    "\n",
    "\n",
    "- 其他链接：\n",
    "  - Kaggle: [link](https://www.kaggle.com/t/ac77388c90204a4c8daebeddd40ff916)\n",
    "  - Slide: [link](https://docs.google.com/presentation/d/1HLAj7UUIjZOycDe7DaVLSwJfXVd3bXPOyzSb6Zk3hYU/edit?usp=sharing)\n",
    "  - Data: [link](https://drive.google.com/drive/folders/1vI1kuLB-q1VilIftiwnPOCAeOOFfBZge?usp=sharing)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **加载包**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:27.289682Z",
     "iopub.status.busy": "2022-12-18T14:20:27.288667Z",
     "iopub.status.idle": "2022-12-18T14:20:36.302667Z",
     "shell.execute_reply": "2022-12-18T14:20:36.301471Z",
     "shell.execute_reply.started": "2022-12-18T14:20:27.289638Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torchviz in /opt/conda/lib/python3.7/site-packages (0.0.2)\n",
      "Requirement already satisfied: torch in /opt/conda/lib/python3.7/site-packages (from torchviz) (1.11.0)\n",
      "Requirement already satisfied: graphviz in /opt/conda/lib/python3.7/site-packages (from torchviz) (0.8.4)\n",
      "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch->torchviz) (4.1.1)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install torchviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.305765Z",
     "iopub.status.busy": "2022-12-18T14:20:36.305350Z",
     "iopub.status.idle": "2022-12-18T14:20:36.314953Z",
     "shell.execute_reply": "2022-12-18T14:20:36.313884Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.305726Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "from torch.utils.data import Dataset, random_split, DataLoader\n",
    "from torch import functional  as F\n",
    "from torch.optim import Optimizer\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "import math\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "# 绘制评估曲线\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import matplotlib.pyplot as plt\n",
    "from torchviz import make_dot\n",
    "\n",
    "import warnings \n",
    "from rich.console import Console\n",
    "warnings.filterwarnings('ignore')\n",
    "cs = Console()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **下载数据**\n",
    "```python\n",
    "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partaa\n",
    "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partab\n",
    "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partac\n",
    "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partad\n",
    "!cat Dataset.tar.gz.part* > Dataset.tar.gz\n",
    "# unzip the file\n",
    "!tar zxvf Dataset.tar.gz\n",
    "```\n",
    "如果https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/ 下载不了可以用以下途径下载数据\n",
    "- [Kaggle Data: ml2022spring-hw4](https://www.kaggle.com/competitions/ml2022spring-hw4/data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **一些重要的函数**\n",
    "- all_seed 设置随机种子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.317134Z",
     "iopub.status.busy": "2022-12-18T14:20:36.316742Z",
     "iopub.status.idle": "2022-12-18T14:20:36.328224Z",
     "shell.execute_reply": "2022-12-18T14:20:36.327123Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.317100Z"
    }
   },
   "outputs": [],
   "source": [
     "def model_plot(model_class, input_sample):\n",
     "    # Instantiate the model, run one forward pass on input_sample, and\n",
     "    # render the autograd graph with torchviz for visual inspection.\n",
     "    clf = model_class()\n",
     "    y = clf(input_sample) \n",
     "    clf_view = make_dot(y, params=dict(list(clf.named_parameters()) + [('x', input_sample)]))\n",
     "    return clf_view\n",
     "\n",
     "\n",
     "def all_seed(seed=6666):\n",
     "    # Seed every RNG source so runs are reproducible.\n",
     "    np.random.seed(seed)\n",
     "    random.seed(seed)\n",
     "    # CPU\n",
     "    torch.manual_seed(seed)\n",
     "    # GPU (all devices, then the current one)\n",
     "    if torch.cuda.is_available():\n",
     "        torch.cuda.manual_seed_all(seed)\n",
     "        torch.cuda.manual_seed(seed)\n",
     "    # Python hash seed (process-global)\n",
     "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
     "    # cuDNN: force deterministic kernels.\n",
     "    # NOTE(review): enabled=False turns cuDNN off entirely, which can slow\n",
     "    # GPU training considerably; deterministic=True alone is usually enough.\n",
     "    torch.backends.cudnn.deterministic = True\n",
     "    torch.backends.cudnn.benchmark = False\n",
     "    torch.backends.cudnn.enabled = False\n",
     "    print(f'Set env random_seed = {seed}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.332213Z",
     "iopub.status.busy": "2022-12-18T14:20:36.331944Z",
     "iopub.status.idle": "2022-12-18T14:20:36.340340Z",
     "shell.execute_reply": "2022-12-18T14:20:36.339277Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.332189Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Set env random_seed = 87\n"
     ]
    }
   ],
   "source": [
    "all_seed(87)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **数据集**\n",
    "- 原始数据集 [Voxceleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html).\n",
    "- The [license](https://creativecommons.org/licenses/by/4.0/) and [complete version](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt) of Voxceleb2.\n",
    "- 我们从Voxceleb2数据集中随机抽取600个演讲者 \n",
    "- 将数据原始波形转换为mel谱图\n",
    "\n",
    "- 文件夹的结构如下:\n",
    "  - data directory   \n",
    "  |---- metadata.json    \n",
    "  |---- testdata.json     \n",
    "  |---- mapping.json     \n",
    "  |---- uttr-{random string}.pt   \n",
    "\n",
    "- metadata.json中的信息\n",
     "  - \"n_mels\": 40, mel图谱的维度.\n",
    "  - \"speakers\": 字典. \n",
    "    - Key: speaker ids.\n",
    "    - value: \"feature_path\"-特征文件 and \"mel_len\"-特征的长度\n",
    "\n",
    "\n",
    "为了更加高效, 我们在训练的时候将mel图谱分割成一定的长度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.342846Z",
     "iopub.status.busy": "2022-12-18T14:20:36.341650Z",
     "iopub.status.idle": "2022-12-18T14:20:36.450510Z",
     "shell.execute_reply": "2022-12-18T14:20:36.449187Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.342809Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">mapping.json | keys =  <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">dict_keys</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008000; text-decoration-color: #008000\">'speaker2id'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'id2speaker'</span><span style=\"font-weight: bold\">])</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "mapping.json | keys =  \u001b[1;35mdict_keys\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'speaker2id'\u001b[0m, \u001b[32m'id2speaker'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">metadata.json | keys =  <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">dict_keys</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008000; text-decoration-color: #008000\">'n_mels'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'speakers'</span><span style=\"font-weight: bold\">])</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "metadata.json | keys =  \u001b[1;35mdict_keys\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'n_mels'\u001b[0m, \u001b[32m'speakers'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">matedata_js<span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'n_mels'</span><span style=\"font-weight: bold\">]</span>= <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "matedata_js\u001b[1m[\u001b[0m\u001b[32m'n_mels'\u001b[0m\u001b[1m]\u001b[0m= \u001b[1;36m40\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">matedata_js<span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'speakers'</span><span style=\"font-weight: bold\">][</span><span style=\"color: #008000; text-decoration-color: #008000\">'id00559'</span><span style=\"font-weight: bold\">][</span>:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span><span style=\"font-weight: bold\">]</span>=\n",
       "<span style=\"font-weight: bold\">[</span>\n",
       "    <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'feature_path'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'uttr-2918eae600684146903d49f02275cb94.pt'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'mel_len'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">397</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'feature_path'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'uttr-ba3892259e03442a8113180f0e4630c5.pt'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'mel_len'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">603</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'feature_path'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'uttr-787eb20b357a4e2ebfb267611efbf92a.pt'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'mel_len'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">584</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'feature_path'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'uttr-863c9d7ab6bd4a70938481549e426fdc.pt'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'mel_len'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">457</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'feature_path'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'uttr-a9145f8b565f4d18894bce985b9ee3af.pt'</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'mel_len'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">648</span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"font-weight: bold\">]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "matedata_js\u001b[1m[\u001b[0m\u001b[32m'speakers'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'id00559'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m:\u001b[1;36m5\u001b[0m\u001b[1m]\u001b[0m=\n",
       "\u001b[1m[\u001b[0m\n",
       "    \u001b[1m{\u001b[0m\u001b[32m'feature_path'\u001b[0m: \u001b[32m'uttr-2918eae600684146903d49f02275cb94.pt'\u001b[0m, \u001b[32m'mel_len'\u001b[0m: \u001b[1;36m397\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[1m{\u001b[0m\u001b[32m'feature_path'\u001b[0m: \u001b[32m'uttr-ba3892259e03442a8113180f0e4630c5.pt'\u001b[0m, \u001b[32m'mel_len'\u001b[0m: \u001b[1;36m603\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[1m{\u001b[0m\u001b[32m'feature_path'\u001b[0m: \u001b[32m'uttr-787eb20b357a4e2ebfb267611efbf92a.pt'\u001b[0m, \u001b[32m'mel_len'\u001b[0m: \u001b[1;36m584\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[1m{\u001b[0m\u001b[32m'feature_path'\u001b[0m: \u001b[32m'uttr-863c9d7ab6bd4a70938481549e426fdc.pt'\u001b[0m, \u001b[32m'mel_len'\u001b[0m: \u001b[1;36m457\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[1m{\u001b[0m\u001b[32m'feature_path'\u001b[0m: \u001b[32m'uttr-a9145f8b565f4d18894bce985b9ee3af.pt'\u001b[0m, \u001b[32m'mel_len'\u001b[0m: \u001b[1;36m648\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[1m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">(</span>mel_len, n_mels<span style=\"font-weight: bold\">)</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">torch.Size</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">397</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40</span><span style=\"font-weight: bold\">])</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m(\u001b[0mmel_len, n_mels\u001b[1m)\u001b[0m=\u001b[1;35mtorch.Size\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m397\u001b[0m, \u001b[1;36m40\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor([[  5.9532,   5.7866,   4.2582,  ...,  -7.4547,  -7.9219,  -2.2850],\n",
       "        [  5.6307,   4.8402,   1.5912,  ...,  -6.9332,  -7.1366,  -4.4466],\n",
       "        [  5.2866,   5.0213,   2.9144,  ...,  -7.2181,  -7.3338,  -4.4947],\n",
       "        ...,\n",
       "        [-17.3427, -17.0293, -16.7071,  ..., -18.6104, -18.7724, -18.9308],\n",
       "        [-20.7233, -20.7233, -20.7233,  ..., -20.7233, -20.7233, -20.7233],\n",
       "        [-20.7233, -20.7233, -20.7233,  ..., -20.7233, -20.7233, -20.7233]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看数据\n",
    "data_dir = '../input/ml2022spring-hw4/Dataset'\n",
    "\n",
    "# mapping.json\n",
    "map_ = Path(data_dir) / 'mapping.json'\n",
    "map_js = json.load(map_.open())\n",
    "cs.print('mapping.json | keys = ', map_js.keys())\n",
    "# metadata.json \n",
    "matedata_ = Path(data_dir) / 'metadata.json'\n",
    "matedata_js = json.load(matedata_.open())\n",
    "cs.print('metadata.json | keys = ', matedata_js.keys())\n",
    "cs.print(\"matedata_js['n_mels']=\", matedata_js['n_mels'])\n",
    "cs.print(\n",
    "    \"matedata_js['speakers']['id00559'][:5]=\", \n",
    "    matedata_js['speakers']['id00559'][:5]\n",
    ")\n",
    "\n",
    "mel = torch.load(os.path.join(data_dir, 'uttr-2918eae600684146903d49f02275cb94.pt'))\n",
    "cs.print(f'(mel_len, n_mels)={mel.shape}') \n",
    "mel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.452953Z",
     "iopub.status.busy": "2022-12-18T14:20:36.452508Z",
     "iopub.status.idle": "2022-12-18T14:20:36.468113Z",
     "shell.execute_reply": "2022-12-18T14:20:36.466683Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.452901Z"
    }
   },
   "outputs": [],
   "source": [
    "class myDataset(Dataset):\n",
    "    def __init__(self, data_dir, segment_len=128):\n",
    "        super(myDataset, self).__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.segment_len = segment_len\n",
    "        # 加载演讲者和id编码的映射.\n",
    "        mapping_path = Path(data_dir) / 'mapping.json'\n",
    "        mapping = json.load(mapping_path.open())\n",
    "        self.speaker2id = mapping['speaker2id']\n",
    "        \n",
    "        # 加载训练数据的源数据(特征文件， 演讲者)\n",
    "        metadata_path = Path(data_dir) / 'metadata.json'\n",
    "        metadata = json.load(metadata_path.open())['speakers']\n",
    "        \n",
    "        # 获取总演讲者数\n",
    "        self.speaker_num = len(metadata.keys())\n",
    "        self.data = []\n",
    "        for speaker, utt in metadata.items():\n",
    "            for utt_i in utt:\n",
    "                self.data.append([utt_i['feature_path'], self.speaker2id[speaker]])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        feat_path, speaker = self.data[index]\n",
    "        # 载入经过预处理的mel图谱特征(mel-spectrogram)\n",
    "        mel = torch.load(os.path.join(self.data_dir, feat_path))\n",
    "        # 分割 mel-specrogram\n",
    "        if len(mel) > self.segment_len:\n",
    "            # 开始的位置为随机\n",
    "            start = random.randint(0, len(mel) - self.segment_len)\n",
    "            # 切分语音\n",
    "            mel = torch.FloatTensor(mel[start: start + self.segment_len])\n",
    "        else:\n",
    "            mel = torch.FloatTensor(mel)\n",
    "        # 将speaker 转成long格式便于后续计算loss\n",
    "        speaker = torch.FloatTensor([speaker]).long()\n",
    "        return mel, speaker\n",
    "    \n",
    "    def get_speaker_number(self):\n",
    "        return self.speaker_num\n",
    "\n",
    "\n",
    "class InferenceDataset(Dataset):\n",
    "    def __init__(self, data_dir):\n",
    "        super(InferenceDataset, self).__init__()\n",
    "        test_path = Path(data_dir) / 'testdata.json'\n",
    "        metadata = json.load(test_path.open())\n",
    "        self.data_dir = data_dir \n",
    "        self.data = metadata['utterances'] \n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        utt = self.data[index]\n",
    "        feat_path = utt['feature_path']\n",
    "        mel = torch.load(os.path.join(self.data_dir, feat_path))\n",
    "        return feat_path, mel\n",
    "    \n",
    "    \n",
    "def inference_collate_batch(batch):\n",
    "    feat_path, mels = zip(*batch)\n",
    "    return feat_path, torch.stack(mels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  &#x2728;  **Transformer模型**\n",
     "<font color=darkred><b>***TODO***: encoder改用Conformer</font></b>  \n",
    "<font color=darkred><b>***TODO***: 增加Self-Attention Pooling Layer</font></b>  \n",
    "\n",
     "- 可以参考[https://github.com/sooftware/conformer](https://github.com/sooftware/conformer)\n",
    "- self-attetion & multi-self-attention & transformer block可以看李老师的视频\n",
    "    - [B站视频 第五讲 Transformer-2](https://www.bilibili.com/video/BV1m3411p7wD?p=33&vd_source=f209dda877a0d7be7d5309f93b340d6f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.470288Z",
     "iopub.status.busy": "2022-12-18T14:20:36.469802Z",
     "iopub.status.idle": "2022-12-18T14:20:36.483013Z",
     "shell.execute_reply": "2022-12-18T14:20:36.481897Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.470252Z"
    }
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "    def __init__(self, input_dim=40, d_model=80, n_spks=600, dropout=0.1):\n",
    "        super(Classifier, self).__init__()\n",
    "        self.pre_net = nn.Linear(input_dim, d_model)\n",
    "        # TODO:\n",
    "        #   尝试改变Transformer， 改成Conformer.\n",
    "        #   https://arxiv.org/abs/2005.08100 \n",
    "        self.encoder_layer = nn.TransformerEncoderLayer(\n",
    "            d_model=d_model,  # self_attn [Q, K, V] shape=(d_model*3, d_model)\n",
    "            dim_feedforward=256,\n",
    "            nhead=2, \n",
    "            batch_first=True,\n",
    "            activation='gelu'\n",
    "        )\n",
    "        self.pred_layer = nn.Sequential(\n",
    "            nn.Linear(d_model, d_model),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_model, n_spks)\n",
    "        )\n",
    "    \n",
    "    def forward(self, mels):\n",
    "        # out: (batch_size, length, d_model)\n",
    "        out = self.pre_net(mels)\n",
    "        out = self.encoder_layer(out)\n",
    "        # mean pooling\n",
    "        stats = out.mean(dim=1)\n",
    "        return self.pred_layer(stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.485018Z",
     "iopub.status.busy": "2022-12-18T14:20:36.484618Z",
     "iopub.status.idle": "2022-12-18T14:20:36.493162Z",
     "shell.execute_reply": "2022-12-18T14:20:36.492134Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.484977Z"
    }
   },
   "outputs": [],
   "source": [
    "# x = torch.randn(1, 100, 40).requires_grad_(True)\n",
    "# model_plot(Classifier, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## &#x1F526;**学习率设置**\n",
    "- 对于transformer结构, 学习率的设计和CNN有所不同\n",
    "- 一些相关工作表明在训练前期逐步增加学习率（Warm up）有利于模型训练transformer.\n",
    "- 按照`plot_lr`设计一个Warm up的学习变化架构\n",
    "  - 设置学习率在 0到优化器设置的学习率的区间\n",
     "  - 在初期（Warmup period）学习率从零逐步增长到初始学习率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:49:03.261176Z",
     "iopub.status.busy": "2022-12-18T14:49:03.260475Z",
     "iopub.status.idle": "2022-12-18T14:49:03.525534Z",
     "shell.execute_reply": "2022-12-18T14:49:03.524535Z",
     "shell.execute_reply.started": "2022-12-18T14:49:03.261138Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEmCAYAAACEQCxyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8XElEQVR4nO3dd3hVVdb48e9KJxBCC70qoAaUYqSIbcaGFWdsiAUVwe44TtN3fjOvOuqMzmtvAwh2mlgGxzqKBaUGxEKTSBGQLl0gbf3+2Dt4jSmXkOTcm7M+z3Mf7j3trpOEvc7Z+5x1RFUxxhgTTglBB2CMMSY4lgSMMSbELAkYY0yIWRIwxpgQsyRgjDEhZknAGGNCzJKAiSsicoKIrK7iui1E5GMR2SEi91d3bNVFRHaKyEFBx2HCwZKAqRLfUJW8ikVkd8Tni4OOrxwjgE1AQ1X9XemZIvKMiNxV+2H9lKo2UNVl1b1dEbldRAr872iriEwXkf77sb6KSOfqjssEy5KAqRLfUDVQ1QbAt8BZEdNeLFlORJKCi/JnOgALNcA7JGPg5zHR/86aAR8ALwUcjwmYJQFTrUq6a0TkTyKyDnhaRBJE5FYR+UZENovIJBFp4pfv6I8wh4rItyKySUT+HLG9ev4IfYuILASOquT7jxaROSKyzf97tJ/+DDAU+KM/Ej5pP/frTBGZH3EEfUTEvJJ92yEiC0XkVxHzLheRT0XkQRHZDNzu9+dxEXnDrzNLRA6OWGffEXcUy54iIkv8/j4hIh+JyFWV7Y+qFgIvAm1EJMtvq4+IzPD7uFZEHhORFD/vY7/q5/7nd2FlPxcTHywJmJrQEmiCO/IeAdwInAMcD7QGtgCPl1rnGOAQ4ETgryJymJ/+v8DB/nUqriEvk08sbwCPAE2BB4A3RKSpql6Oa/Tu82cr70W7MyLSCxgLXO23OxKYIiKpfpFvgGOBTOAO4AURaRWxib7AMqAFcLefNtgv2xjIi5heljKXFZFmwGTgNh/XEuDoKPcpBbgM2Iz7fQAUAb/FnSX0x/0urgNQ1eP8Mj38z29iFD8XEwcsCZiaUAz8r6ruVdXdwDXAn1V1taruBW4HzivVNXKHqu5W1c+Bz4EefvoFwN2q+r2qrsI18OU5A1iqqs+raqGqjgcWA2cd4P6MAEaq6ixVLVLVZ4G9QD8AVX1JVb9T1WJVnQgsBfpErP+dqj7qY9rtp72qqrMjjsh7VvD95S17OrBAVV/x8x4B1lWyLxeIyFZgNzAcOM+vi6rOVdWZPs4VuEb9+Kr+XEx8sCRgasJGVd0T8bkD8KrvMtgKLMIddbaIWCay8foBaODftwZWRcxbWcH3ti5j/kqgTfShl6kD8LuS+P0+tPPfh4hcFtElshXojjuaLrGq9AYpf3/LEtXPxo91VHbl1CRVbYT72X8FHFkyQ0S6ish/RGSdiGwH7im1H6VV+HMx8cGSgKkJpQdeVwGnqWqjiFeaqq6JYltrcQ1LifYVLPsdrmGK1B6I5nsqsgp3NhIZf7qqjheRDsBo4AagqW9gvwIkYv2aGoheC7Qt+SAiEvm5Iqq6CXckf3tE19WTuDOnLqraEPgffrofpZX7c6nCvpiAWBIwteFfwN2+wUREskRkUJTrTgJuE5HGItIWN75QnjeBriIyRESS/OBlNvCf/Yg1UUTSIl4puEb+GhHpK059ETlDRDKA+rhGfqPftytwZwK14Q3gcBE5x3etXY8bj4mKqi4B3gH+6CdlANuBnSJyKHBtqVXWA5H3L1T0czFxwpKAqQ0PA1OAd0VkBzATN1gajTtwXTrLgXeB58tbUFU3A2cCv8MNeP4RONMf9UbrVlx/eclrqqrm4vrPH8MNouYBl/vvXAjcD8zANZKHA5/ux/dVmd+v84H7cPubDeTi+uWj9U9ghIg0B34PDAF24Br4iaWWvR141nf9XFDRz8XED7GHyhhTN4hIAm5M4GJV/SDoeEx8
sDMBY+KYiJwqIo38ZZklffgzAw7LxBFLAsbEt/64+xQ24S6FPSfiMlRjKmXdQcYYE2J2JmCMMSFmScAYY0LMkoAxAfIF5j4JOo4SIpIiIpNFZIUvZHdCqfkiIveKKwS42b+XiPk9RWSuiPzg/+0Z7bomGJYETJ0mwZdujkefAJdQdh2iEbhigD2AI3CD0VfDvqJ0/wZewBW7exb4d0kl0orWNQFSVXvZC1UFWIG7YegLYBvuZqE03A1An5RaVoHO/v0zwBPAW8BO3M1SLYGHcDcRLQZ6VfLdVwCvR3xeCrwU8XkV0NO/f9h/3g7MBY6NWO52XGXNF/z8q4APgbuA6T6+13FVL1/0y8wBOvr1O/p9S4rY5ofAVf795X7/HvM/o8XAiVH8bC/HVRLdgbvx7WLgMGAPro7STmCrXzYV+D/ccxrW4+64rufnnYC7F+B/cFcErcDdF1DyPacDC/33rAF+fwB/D6uBE0pNmw6MiPg8DJjp35/iv1Mi5n8LDKxsXXsF97IzAVPaBcBAoBPuaO3y/Vjv/+EKju3F3UE7z3+ejCvrXJGPgGPFPXugNZCCu/wRcY9abIBLTuAa7Z64ctXjgJdEJC1iW4P8dzbCNfTgyjFfiismd7CP72m/jUW4ktXR6ou7LLOZX+8VX8a6TCJSH1fh8zRVzcCVe56vqotwFVZnqCvP3Miv8g+gq9/Hzj7mv0ZssqX/7ja40tqjROQQP28McLX/nu7AVB9D+8hCb2W8hkS5791wVV5LfO6nlcz7Qn0L731Ran5565qAWBIwpT2irizy97gj5p5RrvequlLEe4BXgT2q+pyqFuHOKHpVtLK6xynu8N93HK6mzXe+hs3xwDRVLfbLvqCqm9WVPL4fd+R8SMTmZqjqa+pKO5dcM/+0qn6jqttwZyzfqOp76soov1RZfKVsAB5S1QJ1paOX4MpYV6QY6C4i9VR1raouKGsh30c+AvituvLZO3DVPAeXWvQv6kp1f4SrIXSBn14AZItIQ1XdoqrzAFT1W/1pobfSr3FR7nsD3BlQiW1AAx936Xkl8zOiWNcExJKAKW1/ShxHWh/xfncZn6PZzke47o7j/PsPcQngeP8ZABH5vYgsEvc0ra24h7lUVrq5OuIrsabU0e5KKiifrKq7gAtxR/1rxT0l7NByFs8C0oG5EeWZ3/bTS2zx2yzr+8/FdQmtFPeUsaifIRylnUDDiM8NgZ3+51F6Xsn8HVGsawJiScBEYxeuYQJARKKuVLmfSpLAsf79R5RKAiJyLK4w3AVAY9+Fso3qK91c0rimR0wrvb9tSh29tseVsS6Xqr6jqicDrXDjCKPLiXUTLil1izhKz1T3XOASjX0X08++X1XnqOogoDnwGq4Ka0l30M4KXhdXFH+EBfz4wB/8+wUR844o9bM5otT88tY1AbEkYKLxOdDNX/6Xhht8rQkfAb/ADYKuBqbhxieaAp/5ZTKAQlzp5iQR+Ss/P/qsMlXdiBvcvEREEkXkStwYQqTmwE0ikiwi5+MGeN8sb5si0kJEBvmGey/uiLjYz14PtC25gsZ3eY0GHvSVPRGRNiJyaqnN3uEv5zwWVzn1Jf/5YhHJVNUC3KB3SRfat37cobxXydgJIpIaMcaSIq6kdknD/hxwi4+pNa5i6zN+3oe4Qe6b/DZu8NOnRrGuCYglAVMpVf0auBN4D3fVTo1c1+6/Zyeu8UdVt+OuqPnUjy2AGyt4G/ga1w2yh7K7fw7EcOAPuPLM3XBXtUSaBXTBHbXfjXtE4+YKtpcA3II7Wv8ed2ZTUqt/Ku5oeJ2IlJS8/hOuLPNMcU/4eo+fjnmsw1119R1u4PsaVV3s510KrPDrXYO7Cml/LcGdjbTB/bx38+PDekbixoq+xD085w0/DVXNx10CehmwFbgSV8sov7J1TXCsdpAx+0FELsddLnpMQN9/AvCCqkb1BDFjKmNnAsYYE2J2N6WpNSLSHncjU1myVfXb2oynuonIznJmnaaq
02o1GGOiZN1BxhgTYtYdZIwxIWZJwMQ9EfmXiPylupc1JgwsCZhAiStZfNKBbENVr1HVv1X3stVNRJ4RkbuC+O6y+IQYecPYXhHZETG/iYi8KiK7RGRl6fpCIjLET98lIq9F1k86kHVN7bIkYGKaWCnoGuMT4r4bxoDxuDpKJR4H8oEWuPsNnhSRbgD+35G4+xJa4EqMPFFN65raVBulSu0Vuy+CLR/9PO6O1t1+G3/kx1LOw3BliD/2y76Eu0lqG/AxrqwCEbHc5d+fgCuB/Dtcobe1wBVVXLYp7uamknLTd5X+mZSxTwI86Le3HXdjVHdcUbgCXMO4E182G1fz52XcHdDLgZsitnU7rhrqRFz9nXlAj4j5f8Ld3bwDd4NXpSWtK4i7vt/O8RGf84GupX5f//Dv7wHGRcw72C+fcSDrBv3/IYwvOxMwEFD5aFW9FNfQn6XuaPS+iNnH48oxlJRLeAt3l25z/x0vUr6WuKJybXDJ5HERaVyFZR/H1RJqiSvZPLSi/fFOwRXA6+q3ewGwWVVH+Zjv8/t6logk4JLM5/77TwRuLlUiYhAuAZaUzX7Nl6s4BLgBOEpd2ehTcQm9pKulorLR7cuI+1xcIvrYf+4KFKq7i7tE6bLR+8pCq+o3+Ib/ANc1tcySgIGAykdX4nZV3aW+FLSqjlXVHaq6F3eE3ENEMstZtwC4U12p5zdxR96H7M+yIpKIaxj/V1V/UNWFuCdlVaYAdzR8KO4S7EWquracZY8CslT1TlXNV1dOezQ/LRs9V1Unq6sF9ADuLK0frkZPKq5sdLKqrvCNKao6TisuG13W/RhDgedUteSa8Qa4M5lIFZWFjpx/IOuaWmZJwECw5aPLs68ekC/k9g8R+cbXxFnhZzUrc0135F0Y8bmifSpv2SzczZSRdYkqrVGkqlNxTx17HNggIqNEpLwCdx2A1pFH6bgnhrUo6zvVFZdbDbRW1TzgZlxC3CAiE3xRtv3mzwxOwBV4K7G/ZaEj5x/IuqaWWRIw5amt8tHl3a0YOX0IrlvkJFwXS8eSsGooJnBdI4VAZI2edtGsqKqPqOqRQDaui+MPJbNKLboKWF7qKD1DVU8v6zt991FbfiwbPU5dDaMOftv3+uUurqRsdOnuoEtxRfqWRUz7GleltUvEtNJlo/eVhRb39LdUv96BrGtqmSUBU57aKh+9HjiokmUycGMOm3GJ6Z4aimUf36X1CnC7iKSLewjMZZWtJyJHiUhfEUnGJdI9/LRsdOS+zgZ2iMifRKSeP+PpLiJHRSxzpIj82l8ldTPu5zBTRA4RkV+KSKr/jt38WDb6Ra24bHTp7qDLKFXSWd1Da14B7hSR+iIyAJeIn/eLvAicJSLHiiuRfSfwiu+yq/K6lf18TfWzJGDKpLVUPhr4O/D/fHfI78tZ5jlc2eg1uNpDM2soltJuwJ15rMM1YONxjXBFGuL69bfgYt4M/NPPG4Prw98qIq/5RHMmbgxmOa409VP+O0v8G/dUsi24I/Zf+/GBVNyziDf5+JoDt+3vDop78lhbfnppaInrgHq4K53GA9eqfyym//caXIO+AZeor6umdU0tstpBxkRJRO4FWqpqNFcJVcf33Y67JPeS2vg+E052JmBMOUTkUBE5Qpw+uEtIXw06LmOqk92NaWqUxHf56AxcV0ZrXH/+/cC/xT3S8a2yVtCfPgvYmJhn3UHGGBNi1h1kjDEhFlfdQc2aNdOOHTsGHYYxxsSNuXPnblLVrPLmx1US6NixI7m5uUGHYYwxcUNEVlY037qDjDEmxCwJGGNMiFkSMMaYELMkYIwxIWZJwBhjQiyqJCAiA0VkiYjkicitZcxPFZGJfv4sEenopzcVkQ98+drHSq1zpIh86dd5RERqsiywMcaYMlSaBPwTlh4HTsPVR79IRLJLLTYM2KKqnXHPV73XT98D/AX3DNvSngSG4x4Z2AX3eENjjDG1KJr7BPoAeSUPnBCRCbja4JH1YAbxY735
ycBjIiK+rvgnItI5coMi0gpoqKoz/efngHMopx5LdVuzdTeTc1eTmAApSQmkJCaQnJRARloyjdOTaZyeQuP6KTStn0JacmJthGSMMYGIJgm04aeP1VsN9C1vGVUtFJFtQFNcrfPytrm61DbblLWgiIwARgC0b1/W87H336Q5q3j4/aVRLdu0fgptG9ejbeN02jSux0HN6tO1ZQZdW2TQIDWu7rUzxpififlWTFVHAaMAcnJyqqXa3d7CYpIThQV3DCS/qJj8QvfasaeA73fls+WHArb+kM+mnXtZs3U3q7fsZtHa7fx30XryC4v3badt43oc2jKDHm0b0btDY3q0a2SJwRgTV6Jpsdbw02ertvXTylpmtX8MXibuiUoVbTPy2a1lbbPGFBQVk5KY4LqCkhLcM5qAlplpFa5XXKys2bqbxet28PX6HSxZt4OFa7fz3qINACQIdG2RwVEdm3BMl2b0P7gpDdOSa3p3jDGmyqJJAnOALiLSCddQD8Y9+DvSFGAoMAM4D5iqFdSoVtW1IrJdRPoBs3DPOH20CvFXSUFRMclJ+391bEKC0K5JOu2apHNydot907ftLmD+qq3MW7mFed9u4eV5q3l+5koSE4QebTM5pksWJx3WnMPbZGIXQRljYkmlScD38d8AvAMkAmNVdYGI3AnkquoU3LNTnxeRPOB7XKIAQERW4J67miIi5wCnqOpC3DNFn8E9h/QtamlQGKCgSElOrL5bJDLrJXN81yyO7+oK9eUXFvPZt1v4JG8T05Zu4rGpS3nk/aW0zkzjlG4tGdi9JUd1bEJigiUEY0yw4uqhMjk5OVodVUR//9LnTM/bxPTbTqyGqCq3ZVc+7y/ewNtfrWPa0o3sLSymaf0UzurRmnN7t6V7m4Z2hmCMqREiMldVc8qbH8pRzKp2B1VV4/opnHdkW847si279hby0dcbeeOLtYyb/S3PTF9Bl+YN+HXvtvyqV5tKxyWMMaY6hTcJVGN30P6on5rE6Ye34vTDW7HthwLe+HItL89bzb1vL+af7yzmpMNacFn/jgzo3NTODowxNS6kSaB6xwSqKjM9mSF92zOkb3tWbNrFhDmrmJS7incXruegrPpc2q8D5x7Z1q4wMsbUmOBbwgC4M4HYOsru2Kw+t552KNNv/SUPXNCDzHrJ3PH6Qvrf8z53v7GQddv2BB2iMaYOCumZQHDdQZVJS07k173b8uvebflqzTaemraMsZ+u4JnpKzinZxuuPv4gOjfPCDpMY0wdEZstYQ0rKNSYOxMoS/c2mTw0uBcf/v4EhvRpz+tffMdJD3zM1c/nsmjt9qDDM8bUAeFMAsWxeyZQlnZN0rljUHc+/dMvuemXnZn+zWZOe3ga14+bR96GHUGHZ4yJY/HTElajWO4OqkjTBqnccsohfPLHX3LDLzrz4eINnPLgx9wycT4rN+8KOjxjTByKv5awGsRLd1B5MtOT+f2phzDtT79k+LEH8eZXaznx/o+4fcoCtuzKDzo8Y0wcCWcSiNMzgdKa1E/httMP4+M//IILjmrHczNWcPw/P+Cpact+Uu3UGGPKE/8tYRXk+yqidUXzhmnc86vDees3x9GrfWPuemMRpzz4EW9/tZZ4KgtijKl9dacl3A+FMXKzWHU7pGUGz17Zh2euOIqUpASueWEel46ZzbKNO4MOzRgTo+peSxiFgqJikuJ4TKAyJxzSnDdvOpY7B3Xj89VbGfjQNO5/dwl7CoqCDs0YE2NCmQTy68iYQEWSEhO4rH9Hpv7uBM48ohWPTs3j5Ac/4v1F64MOzRgTQ+p2S1iOgqJi90SxEMjKSOWBC3syfng/UpMSGfZsLtc8P5cNO6wMhTEmpEnAjQnU3e6gsvQ/uClv3nQsfxx4CFOXbODkBz7m5bmrbeDYmJALXRIoLlYKi5WkhNDtOilJCVx3Qmfe+s2xdG7egN+99DlXPjOHtdt2Bx2aMSYgoWsJC4rd9fNh6Q4qy8FZDZh0dX/+96xs
Zi77nlMe+Jjxs7+1swJjQih0LWFBkWvowtYdVFpignDFgE68c/NxdG+TyW2vfMkVz8yxsQJjQiZ0SaCwyJ0J1PWrg6LVvmk644b35c5B3ZjxzWYGPjSNdxesCzosY0wtCV1LmG9J4GdEhMv6d+SNm46hdaM0Rjw/l1tf/oJdewuDDs0YU8NC1xJad1D5OjfP4JVrB3DtCQczMXcVpz8yjXnfbgk6LGNMDQpfEii0M4GKpCQl8KeBhzJheD8Ki5Tz/zWDJz/8huJiGzQ2pi4KXUtYYN1BUel7UFPeuvlYBnZryb1vL+bKZ+fwvZWpNqbOCV1L+GN3UOh2fb81TEvmsSG9+NugbkzP28zpD09jzorvgw7LGFONQtcS/ngmYGMC0RARLu3fkVeuO5q05AQGj5rJEx/mWfeQMXVEiJNA6Hb9gHRvk8nrNx7Dad1bct/bS7jimTn2FDNj6oDQtYR2iWjVZaQl8+hFvbjrnO7M+GYzZz32CV+t2RZ0WMaYAxC6lrDQjwmkJFl3UFWICJf068DEq93VQ+c+OZ1XP1sddFjGmCoKXRIo6Q4KYwG56tSrfWNev/EYerRrxG8nfs7tUxbs+9kaY+JH6FpCGxOoPlkZqbx4VV+uHNCJZ6av4OKnZrFxx96gwzLG7IeoWkIRGSgiS0QkT0RuLWN+qohM9PNniUjHiHm3+elLROTUiOm/FZEFIvKViIwXkbRq2aNK5Ft3ULVKTkzgr2dl89CFPfli9VbOfHQa81dtDTosY0yUKk0CIpIIPA6cBmQDF4lIdqnFhgFbVLUz8CBwr183GxgMdAMGAk+ISKKItAFuAnJUtTuQ6JercVZArmac06sNL197NMmJCVw4cgavf/5d0CEZY6IQTUvYB8hT1WWqmg9MAAaVWmYQ8Kx/Pxk4UUTET5+gqntVdTmQ57cHkATUE5EkIB2olVbDuoNqTrfWmfz7+gEc0TaTG8d/xoP//druJzAmxkXTErYBVkV8Xu2nlbmMqhYC24Cm5a2rqmuA/wO+BdYC21T13bK+XERGiEiuiORu3LgxinArVtIdlGQ3i9WIpg1SeeGqvpzbuy0Pv7+UGyd8xu78oqDDMsaUI5DDYRFpjDtL6AS0BuqLyCVlLauqo1Q1R1VzsrKyDvi7SwrIpdiZQI1JTUrk/84/gttOO5Q3v1zLhaNmsH67PazGmFgUTUu4BmgX8bmtn1bmMr57JxPYXMG6JwHLVXWjqhYArwBHV2UH9pd1B9UOEeHq4w9m1KU55G3YydmPfcKXq+3GMmNiTTQt4Rygi4h0EpEU3ADulFLLTAGG+vfnAVPVPbB2CjDYXz3UCegCzMZ1A/UTkXQ/dnAisOjAd6dyhcVWQK42nZzdgpevPZqkhATOHzmd/y5cH3RIxpgIlbaEvo//BuAdXEM9SVUXiMidInK2X2wM0FRE8oBbgFv9uguAScBC4G3gelUtUtVZuAHkecCXPo5R1bpn5cgvtAJyte2wVg157foBdG2RwdXP5/L8jBVBh2SM8cQdsMeHnJwczc3NPaBt3Pf2YkZ9vIy8e06vpqhMtH7IL+TGcZ/x/uINXH38Qfzp1ENJSLBkbExNEpG5qppT3vzQ9YkUFBVbV1BA0lOSGHnpkVzSrz0jP1rGbybOZ2+hXTlkTJCSgg6gthUUqXUFBSgpMYG/DepOm0bp3Pv2YtZv38PoS3PITE8OOjRjQil0h8R2JhA8EeHaEw7m4cE9mf/tVs7913RWff9D0GEZE0qhaw0tCcSOQT3b8NywPmzYvodfPzmdBd/ZJaTG1LbQtYYFRUqyFY+LGf0OauovIRUGj5zJrGWbgw7JmFAJYRKwM4FY06VFBpOvPZqshqlcNna23UtgTC0KXWtYUFRsJSNiUJtG9Zh8zdEc2jKDa16Yy0u5qypfyRhzwELXGhYUqRWPi1FN6qcwbng/jj64KX+Y/AUjP/om6JCMqfNCmASsOyiW1U9N4qmhOZxxRCv+/tZi/v7mIuLp
hkZj4k3o7hPIL7QkEOtSkxJ5ZHAvGqcnM/LjZXy/K5+///pwkuz3Zky1C10SKCxW6iUnBh2GqURigvC3Qd1pWj+Vh99fyrbdBTw6pBepSfa7M6Y6he7QqqCo2MYE4oSI8NuTu3L7Wdm8u3A9Vz2baw+oMaaahS4JWHdQ/Ll8QCfuO+8IPs3bxNCxs9mxpyDokIypM0LXGtolovHpgpx2PDy4F/O+3cIlY2az9Yf8oEMypk4IXWtYWGwF5OLVWT1a869LjmTRd9sZPGomG3fsDTokY+Je6JJAgXUHxbWTslsw9vKjWLn5By4cNYO123YHHZIxcS10rWF+kdqlhnHumC7NeG5YHzZu38v5/5rBt5utAqkxVRW61tCNCVh3ULw7qmMTxg3vx869hZw/cjp5G3YGHZIxcSmUScC6g+qGw9tmMnFEf4qK4cKRM1i8bnvQIRkTd0LXGhYWKclJodvtOuuQlhm8dE1/khMTuGjUTBZ+Z4nAmP0RqtZQVckvKibZHm5ep3RqVp+JV/ejXnIiQ56ayVdr7OE0xkQrVEmgsNgVIrPuoLqnQ9P6TBjRn/opSQwZPZMvV1siMCYaoWoNC4qKAaw7qI5q3zSdCSP60bBeMkOemsn8VVuDDsmYmBeq1rCgyM4E6rp2TVwiaJSezKVPzWLet1uCDsmYmBaq1nDfmYBdIlqntW2czsQR/WnSIIXLxsxm7srvgw7JmJgV0iQQqt0OpdaN6jFxRH+yMlK5bMxs5qywRGBMWULVGhYUWndQmLTMTGPCiH60yExj6NjZzFq2OeiQjIk5oWoNC4qtOyhsWjRMY8LwfrTKTOPyp+cw/ZtNQYdkTEwJVxLw3UFWSjpcmjdMY8KI/rRtXI9hz+Qy084IjNknVK1hSXeQFZALn6yMVMYN70frRmlc+cwcGyMwxouqNRSRgSKyRETyROTWMuanishEP3+WiHSMmHebn75ERE6NmN5IRCaLyGIRWSQi/atljyqQb1cHhVpWRirjh/ejZcM0Lh87m7kr7fJRYypNAiKSCDwOnAZkAxeJSHapxYYBW1S1M/AgcK9fNxsYDHQDBgJP+O0BPAy8raqHAj2ARQe+OxWz7iDTvGEa44b3IysjlaFjZ/OZ3UdgQi6a1rAPkKeqy1Q1H5gADCq1zCDgWf9+MnCiiIifPkFV96rqciAP6CMimcBxwBgAVc1X1a0HvDeVKCy5WczuGA61lplpjB/Rjyb1U7hs7Gy+WL016JCMCUw0rWEbYFXE59V+WpnLqGohsA1oWsG6nYCNwNMi8pmIPCUi9cv6chEZISK5IpK7cePGKMItX8mZQJIVkAu9Vpn1GD+iH5n1krnkqVlWdM6EVlCHxElAb+BJVe0F7AJ+NtYAoKqjVDVHVXOysrIO6Evz7WYxE6FNo3qMH96PjLRkLn5qFgu+s0Rgwiea1nAN0C7ic1s/rcxlRCQJyAQ2V7DuamC1qs7y0yfjkkKN2jcmYN1BxmvXJJ3xw/tRPyWRS56aZQ+mMaETTWs4B+giIp1EJAU30Dul1DJTgKH+/XnAVFVVP32wv3qoE9AFmK2q64BVInKIX+dEYOEB7kulCq2AnClD+6bpjBvej9SkRC4ePYuv1+8IOiRjak2lraHv478BeAd3Bc8kVV0gIneKyNl+sTFAUxHJA27Bd+2o6gJgEq6Bfxu4XlWL/Do3Ai+KyBdAT+CeaturctgloqY8HZvVZ9zwviQmCENGzyRvgyUCEw7iDtjjQ05Ojubm5lZ5/RdnreTPr37FrP85kRYN06oxMlNX5G3YyeBRMxGBCSP6cXBWg6BDMuaAiMhcVc0pb36o+kUKCm1g2FSsc/MGjB/eF1XlolEzWb5pV9AhGVOjQtUa/vh4SesOMuXr0iKDF6/qR2GxMmT0TL7d/EPQIRlTY0KVBOwSUROtQ1pm8MKwvuwuKOKi0TNZvcUSgambQtUa2vMEzP7Ibt2QF4b1ZceeAoaMnsXabbuDDsmYaheq1rCg
qJgEgUS7Y9hEqXubTJ4b1pctu/IZMnoW67fvCTokY6pV6JKAnQWY/dWzXSOeufIoNmzfw5DRM9m4Y2/QIRlTbULVIhYUqVUQNVVyZIcmPH1FH77buoeLn5rJ5p2WCEzdEKoWsaComCS7MshUUZ9OTRhzeQ4rN//AxU/NYsuu/KBDMuaAhS4JWHeQORBHH9yM0ZflsGzTLi4dO4ttPxQEHZIxByRULWK+JQFTDY7rmsXIS47k63U7uWzsLLbvsURg4leoWsTCIrUKoqZa/OLQ5jx+cW8WfLedK56ew869hUGHZEyVhKpFdN1BNiZgqsfJ2S149KJezF+1lSufnsMP+ZYITPwJXRJISgjVLpsadtrhrXjowp7krvyeYc/ksju/qPKVjIkhoWoR84vUni9sqt1ZPVpz/wU9mLl8MyOez2VPgSUCEz9C1SIWFhWTYt1Bpgb8qldb7jv3CKYt3cS1L8xlb6ElAhMfQpUE7BJRU5POz2nHPb86nA+WbOT6Fz8j35cuNyaWhapFzC9SkiwJmBo0pG977hzUjfcWrec3Ez7b91xrY2JVqFrEgkLrDjI177L+HfnLmdm89dU6bpn0OYWWCEwMSwo6gNpk3UGmtgw7phOFRcX8/a3FJCUI/3d+D6tea2JSqJJAYbFaEjC15urjD6awWPnnO0tIShDuPfcIEiwRmBgTqiSQX2hnAqZ2Xf+LzuQXFvPw+0tJSkzg7nO6WyIwMSVUScDuGDZBuPmkLhQUFfPEh9+QnCjccXY3ROzv0MSGECYBOxMwtUtE+MOph1BYrIz6eBlJCQn85czDLBGYmBCqJFBYZGMCJhgiwm2nHUpBUTFjP11OcpJw68BDLRGYwIUqCeQXFZOcZP/pTDBEhL+emU1BUTEjP1pGSmICvzvlkKDDMiEXqiRQUFRMshWQMwESEe48uzuFRcqjU/NISkjgNyd1CTosE2KhSQJFxUqxYt1BJnAJCcI9vzqcwmLlwfe+JilRuP4XnYMOy4RUaJJAye371h1kYkGCv2+goKiYf76zhJTEBIYfd1DQYZkQCl0SSLEzARMjEhOE+8/vQWGxcvebi0hKFK4Y0CnosEzIhCgJKABJdqOOiSFJiQk8dGFPCouKueP1hSQlJnBpvw5Bh2VCJDSHxT92B4Vml02cSE5M4NGLenPSYc35y2tfMWH2t0GHZEIkqhZRRAaKyBIRyRORW8uYnyoiE/38WSLSMWLebX76EhE5tdR6iSLymYj854D3pBIltd1tYNjEopSkBB6/uDfHd83itle/ZPLc1UGHZEKi0hZRRBKBx4HTgGzgIhHJLrXYMGCLqnYGHgTu9etmA4OBbsBA4Am/vRK/ARYd6E5Eo7DYdQfZmICJValJiYy89EgGHNyMP0z+nH/PXxN0SCYEomkR+wB5qrpMVfOBCcCgUssMAp717ycDJ4q7FXIQMEFV96rqciDPbw8RaQucATx14LtRuX3dQZYETAxLS05k9GU59O3UhN9OnM8bX6wNOiRTx0XTIrYBVkV8Xu2nlbmMqhYC24Cmlaz7EPBHoMInbojICBHJFZHcjRs3RhFu2Uq6g5KsgJyJcfVSEhkz9CiO7NCYmyZ8xttfrQs6JFOHBXJYLCJnAhtUdW5ly6rqKFXNUdWcrKysKn+nXSJq4kn91CSevqIPR7TN5Mbx83h/0fqgQzJ1VDQt4hqgXcTntn5amcuISBKQCWyuYN0BwNkisgLXvfRLEXmhCvFHrWRMwLqDTLxokJrEs1f24bBWDbn2hXl8uGRD0CGZOiiaFnEO0EVEOolICm6gd0qpZaYAQ/3784Cpqqp++mB/9VAnoAswW1VvU9W2qtrRb2+qql5SDftTroJ9VwdZd5CJHw3Tknn+yr50adGAEc/P5ZOlm4IOydQxlSYB38d/A/AO7kqeSaq6QETuFJGz/WJjgKYikgfcAtzq110ATAIWAm8D16tqUfXvRuXyi0rGBOxMwMSXzPRkXhjWl4Oa1eeq5+Yw
45vNQYdk6hBxB+zxIScnR3Nzc6u07n8Xrmf4c7m8fsMxHN42s5ojM6bmbdq5l4tGzWTN1t08e2UfjurYJOiQTBwQkbmqmlPe/NAcFlsBORPvmjVI5cXhfWmZmcblY2czd+WWoEMydUD4koB1B5k41jwjjfHD+5GVkcrlY2fzxeqtQYdk4lxoWsSSAnJ2iaiJdy0apjFueD8a1U/mkqdm8dWabUGHZOJYaFrEgiK7WczUHa0b1WPcVf3ISEvmkjGzWLR2e9AhmTgVuiRg3UGmrmjXJJ1xw/uSlpTIkNEz7YzAVEloWsSS7iBLAqYu6dC0PhOv7kd6ShJDRs9k/qqtQYdk4kxoWkQrG2HqqpJE0Cg9hUuemsWcFd8HHZKJI6FpEQusgJypw9o2TmfS1f1pnpHKZWNmMz3P7iw20QlPEigZGLbHS5o6qmVmGhOu7ke7JvW44pk5VmvIRCU8SaBYSUlMwD3mwJi6qXlGGhNG9OfgrAaMeG4u/11o1UdNxcKTBAqLrXicCYUm9VMYP7wfh7XK4NoX5tqDaUyFwpMEioqteJwJjcz0ZF64qi892zXixvHzePUze2axKVtoWsX8IrXLQ02oZKQl8+yVfejbqSm3TPqciXO+DTokE4NC0yoWFBWTYt1BJmTcE8qO4tguWfzp5S8Z88nyoEMyMSY0SaCwqJjkpNDsrjH7uIfXH8lp3Vvyt/8s5IF3lxBPJeRNzQpNq1hg3UEmxFKTEnlsSG8uzGnHI1PzuH3KAoqLLREYSAo6gNqSX1Rs9wiYUEtMEP5x7uFkpicz6uNlbNtdwD/P72EHRyEXmiRQUFRMinUHmZATEW477VAy6yXzz3eWsGNPIY9f3Ju05MSgQzMBCU2rWGjdQcYALhFc/4vO3HVOd6Yu2cDQsbPZsacg6LBMQELTKuYX2c1ixkS6pF8HHh7ci7krt3DR6Jls3rk36JBMAEKTBAqKiu1MwJhSzu7RmtGX5bB0/U7OHzmD1Vt+CDokU8tC0ypaEjCmbL84tDnPD+vLxh17+fUT01n4nT2lLExC0yq6MQHrDjKmLH06NWHyNUeTmCBcMHIGn1op6tAITRLItzMBYyp0SMsMXrnuaNo0qsflT8/m3/PXBB2SqQWhaRWtO8iYyrXKrMeka/rTu31jfjNhPqM+/sbuLq7jQtMqFhRad5Ax0cisl8xzw/pwxhGtuOfNxdz5n4V2d3EdFqqbxexMwJjopCYl8ujgXrTISGPsp8vZsH0v91/Qw24qq4MsCRhjypSQIPz1rGxaZaZx95uL2LBjDyMvzaFJ/ZSgQzPVKDStYkGRWtkIY6pg+HEH8diQXny+ehvnPP4peRt2Bh2SqUahaRULrICcMVV25hGtmTCiHz/kF/KrJz61S0jrkKiSgIgMFJElIpInIreWMT9VRCb6+bNEpGPEvNv89CUicqqf1k5EPhCRhSKyQER+U217VAZVpbDYagcZcyB6t2/Mq9cNoHVmPYaOnc342faksrqg0lZRRBKBx4HTgGzgIhHJLrXYMGCLqnYGHgTu9etmA4OBbsBA4Am/vULgd6qaDfQDri9jm9WmoMhd2WDdQcYcmHZN0pl8bX8GdG7Gba98yd1vLKTIrhyKa9G0in2APFVdpqr5wARgUKllBgHP+veTgRNFRPz0Caq6V1WXA3lAH1Vdq6rzAFR1B7AIaHPgu1O2gqJiALtE1JhqkJGWzJihOQzt34HR05Zz9fNz2bW3MOiwTBVFkwTaAKsiPq/m5w32vmVUtRDYBjSNZl3fddQLmFXWl4vICBHJFZHcjRs3RhHuz5UkgaQEOxMwpjokJSZwx6Du3H5WNlMXr+fcJ6ez6nsrPhePAm0VRaQB8DJws6qWWbVKVUepao6q5mRlZVXpe/JLzgSsO8iYanX5gE48fUUfvtu6m7Me+8QGjONQNK3iGqBdxOe2flqZy4hIEpAJbK5oXRFJxiWAF1X1laoEH63CkjEB6w4yptod3zWLKTccQ/OM
VC4dM4unpi2zUhNxJJokMAfoIiKdRCQFN9A7pdQyU4Ch/v15wFR1fwVTgMH+6qFOQBdgth8vGAMsUtUHqmNHKvLjmICdCRhTEzo2q88r1w3glOyW3PXGIm6Z9Dl7CoqCDstEodJW0ffx3wC8gxvAnaSqC0TkThE52y82BmgqInnALcCtft0FwCRgIfA2cL2qFgEDgEuBX4rIfP86vZr3bR9LAsbUvAapSTxxcW9+d3JXXpu/hvP+NZ01W3cHHZapRFRlI1T1TeDNUtP+GvF+D3B+OeveDdxdatonQK31zeQXulNTuzrImJqVkCDceGIXDmvVkN9OnM/Zj37Cw4N7cUyXZkGHZsoRikNjOxMwpnadlN2C124YQJP6KVw6dhYPvfe13U8Qo0LRKhYWWxIwprYdnNWAf98wgF/1bMND7y3l8qdns8keZh9zQtEq/tgdFIrdNSZmpKckcf8FPbj33MOZvfx7znhkGnNWfB90WCZCKFpFu2PYmOCICBce1Z5XrxtAveREBo+ayciP7IllsSJkSSAUu2tMTMpu3ZDXbzyGU7u14O9vLebyp+ewYceeoMMKvVC0iiUF5CwJGBOsjLRkHh/Sm7+d052ZyzZz2kPTeH/R+qDDCrVQtIolZwIpSdYdZEzQRIRL+3XgjZuOoXnDNIY9m8tfXvuK3fl2c1kQQpUErICcMbGjc/MMXrv+aIYf24nnZ67krMc+YeF3ZZYQMzUoFK1igRWQMyYmpSYl8uczsnlhWF+27y7gnMc/5YkP8yj0/2dNzQtFq/jjmIB1BxkTi47p0oy3bz6Ok7Kbc9/bSzj3yeksXb8j6LBCISRJwI8J2MCwMTGrSf0Unrj4SB4f0ptVW3ZzxiOf2FlBLQhFq2iXiBoTP844ohXv/vbHs4JfPzmdr+2soMaEolUs6Q5Ksu4gY+JCswap+84KVm/ZzRmPTOP+d5dYeeoaEIokkF/ozwTs6iBj4krJWcFZR7Tm0al5nPLgx3z0ddUeM2vKFopWsbC4mKQEISHBzgSMiTfNGqTywIU9GXdVX5IShaFjZ3PDuHms3253G1eHUCSBgiK18QBj4tzRnZvx1m+O5ZaTu/LuwvWcdP9HjPlk+b4zfVM1oWgZ8wuLbTzAmDogNSmRm07swrs3H0evDo35238WMvChj3l/0XorSFdFoUgCBUXFdnmoMXVIx2b1efaKo3j68qNAYNizuVw6ZjaL19kdx/srFC1joXUHGVPniAi/OLQ579x8HP97VjZfrtnG6Q9P47ZXvmTdNhsviFYoWsaComKSrXicMXVScmICVwzoxEd/OIHL+nfkpdxVHP/PD7j7jYV8vys/6PBiXiiSQH5RsZ0JGFPHNUpP4fazuzH1dydwxhGtGPPJco69dyoPvLuE7XsKgg4vZoWiZSwoKrZ7BIwJifZN03nggp68+9vjOP6QLB6Zmsex937AA//9mi12ZvAzoWgZC4rUuoOMCZnOzTN44uIj+c+Nx9CnUxMeeX8pR/9jKne+vpC123YHHV7MSAo6gNpQYN1BxoRW9zaZjL4sh6/X7+BfH37DszNW8PzMFfyqVxuuGNCJw1o1DDrEQFkSMMaEQtcWGTxwYU9+e3JXRk9bxsQ5q5iUu5o+nZowtH9HTunWIpTtREiSgJKWHL5frjHm59o1SefOQd255eSuTMpdxXMzVnL9uHm0bJjGkL7tOe/ItrRuVC/oMGtNSJJAMRlpodhVY0yUGqWnMOK4gxl2zEF8sHgDz85YwQP//ZoH3/uaYzo349zebTm1W0vqpSQGHWqNCkXLaLWDjDHlSUwQTspuwUnZLVi5eRcvz1vDK/NWc/PE+TRITeL0w1ty+uGtOPrgZqTUwUfUhiQJWNkIY0zlOjStzy0nd+XmE7swa/n3vDxvNW98sZZJuatpmJbESYe1YGD3lhzXNYu05LpxhhCaJGAF5Iwx0UpIEPof3JT+BzflrnO688nSTbz11TreW7SeVz5bQ1pyAn07NeXY
Ls04vmsWnZs3QCQ+25hwJIFCuzrIGFM1acmJ+7qLCoqKmfHNZqYu3sDHSzdy1xuLuOuNRbRsmEa/g5pwZIfG9O7QmENbNiQxTp5fElUSEJGBwMNAIvCUqv6j1PxU4DngSGAzcKGqrvDzbgOGAUXATar6TjTbrE4FxTYmYIw5cMmJCRzXNYvjumYBsHrLD0xbuolpSzfy6TebeW3+dwDUT0mkR7tGHNaqIYe0zOCwlg3p0qJBTHYhVZoERCQReBw4GVgNzBGRKaq6MGKxYcAWVe0sIoOBe4ELRSQbGAx0A1oD74lIV79OZdusNm5MID6ysjEmfrRtnM5FfdpzUZ/2qCqrt+xm3rdbmLtyC/NXbeWFmSvZ6x96kyBu+XZN6tG2kf+3cTpZGak0Tk+hSf0UGtdPJjWpdhNFNGcCfYA8VV0GICITgEFAZIM9CLjdv58MPCaug2wQMEFV9wLLRSTPb48otlltrDvIGFPTRIR2TdJp1ySdQT3bAFBUrKzcvIsl63awaN0Olm/axarvf+D9xevZtLPsOkZpyQmkJCaQmpzo/k1KoFmDVCZd079G4o4mCbQBVkV8Xg30LW8ZVS0UkW1AUz99Zql12/j3lW0TABEZAYwAaN++fRTh/twp3VqS3Trct4YbY2pfYoJwUFYDDspqwGmHt/rJvN35RazZ+gObduazZVc+W34oYMsP+WzbXUB+YTF7C4vZW1hEfmExDVJrbvg25geGVXUUMAogJyenSs+Pe/DCntUZkjHGHLB6KYl0bp5B5+bBxhFNH8kaoF3E57Z+WpnLiEgSkIkbIC5v3Wi2aYwxpoZFkwTmAF1EpJOIpOAGeqeUWmYKMNS/Pw+Yqu6pz1OAwSKSKiKdgC7A7Ci3aYwxpoZV2h3k+/hvAN7BXc45VlUXiMidQK6qTgHGAM/7gd/vcY06frlJuAHfQuB6VS0CKGub1b97xhhjKiLugD0+5OTkaG5ubtBhGGNM3BCRuaqaU958u27SGGNCzJKAMcaEmCUBY4wJMUsCxhgTYnE1MCwiG4GVVVy9GbCpGsOpSfEUK8RXvPEUK8RXvPEUK8RXvAcSawdVzSpvZlwlgQMhIrkVjZDHkniKFeIr3niKFeIr3niKFeIr3pqM1bqDjDEmxCwJGGNMiIUpCYwKOoD9EE+xQnzFG0+xQnzFG0+xQnzFW2OxhmZMwBhjzM+F6UzAGGNMKZYEjDEmxOp8EhCRgSKyRETyROTWWv7usSKyQUS+ipjWRET+KyJL/b+N/XQRkUd8nF+ISO+IdYb65ZeKyNCI6UeKyJd+nUf8Iz2rGms7EflARBaKyAIR+U2sxisiaSIyW0Q+97He4ad3EpFZfvsTfZlyfCnziX76LBHpGLGt2/z0JSJyasT0av+7EZFEEflMRP4Ty/GKyAr/e5ovIrl+Wsz9HURsr5GITBaRxSKySET6x2K8InKI/5mWvLaLyM2Bx6qqdfaFK1P9DXAQkAJ8DmTX4vcfB/QGvoqYdh9wq39/K3Cvf3868BYgQD9glp/eBFjm/23s3zf282b7ZcWve9oBxNoK6O3fZwBfA9mxGK9fv4F/nwzM8tudBAz20/8FXOvfXwf8y78fDEz077P930Qq0Mn/rSTW1N8NcAswDviP/xyT8QIrgGalpsXc30FEbM8CV/n3KUCjWI7XbzMRWAd0CDrWWmkMg3oB/YF3Ij7fBtxWyzF05KdJYAnQyr9vBSzx70cCF5VeDrgIGBkxfaSf1gpYHDH9J8tVQ9z/Bk6O9XiBdGAe7hnVm4Ck0r973HMr+vv3SX45Kf33ULJcTfzd4J6e9z7wS+A//vtjMl7KTgIx+XeAe4rhcvxFLrEeb8R2TgE+jYVY63p3UBt+/kD7NuUsW1taqOpa/34d0MK/Ly/WiqavLmP6AfPdD71wR9gxGa/vWpkPbAD+izsS3qqqhWVsf19Mfv42oGkV9uFAPAT8ESj2n5vGcLwKvCsic0VkhJ8Wk38HuDOijcDTvqvt
KRGpH8PxlhgMjPfvA421rieBmKYuXcfUNboi0gB4GbhZVbdHzouleFW1SFV74o6w+wCHBhtR+UTkTGCDqs4NOpYoHaOqvYHTgOtF5LjImbH0d4A7U+oNPKmqvYBduC6VfWIsXvzYz9nAS6XnBRFrXU8CsfhA+/Ui0grA/7vBTy8v1oqmty1jepWJSDIuAbyoqq/EerwAqroV+ADXJdJIREoemRq5/X0x+fmZwOYq7ENVDQDOFpEVwARcl9DDsRqvqq7x/24AXsUl2Vj9O1gNrFbVWf7zZFxSiNV4wSXXeaq63n8ONtYD7duK5RfuKGEZ7pSxZMCsWy3H0JGfjgn8k58OAt3n35/BTweBZvvpTXB9no39aznQxM8rPQh0+gHEKcBzwEOlpsdcvEAW0Mi/rwdMA87EHVlFDrRe599fz08HWif599346UDrMtyAXY393QAn8OPAcMzFC9QHMiLeTwcGxuLfQUTM04BD/PvbfayxHO8E4IpY+T9Wa41hUC/cCPvXuD7jP9fyd48H1gIFuCOWYbi+3feBpcB7Eb88AR73cX4J5ERs50ogz78i/3hygK/8Oo9RanBsP2M9Bnca+gUw379Oj8V4gSOAz3ysXwF/9dMP8v8J8nANbKqfnuY/5/n5B0Vs688+niVEXElRU383/DQJxFy8PqbP/WtBybZi8e8gYns9gVz/9/AarmGMyXhxiXUzkBkxLdBYrWyEMcaEWF0fEzDGGFMBSwLGGBNilgSMMSbELAkYY0yIWRIwxpgQsyRgjDEhZknAGGNC7P8DeR7hlq0WNO4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_lr(num_warmup_steps=1000, num_training_steps=70000, lr=0.01, num_cycles=0.5):\n",
    "    \"\"\"Plot the learning-rate curve of the cosine-with-warmup schedule.\n",
    "\n",
    "    Mirrors the formula in `get_cosine_schedule_with_warmup` so the figure\n",
    "    matches what the optimizer will actually see. The defaults reproduce\n",
    "    the original hard-coded plot, so plot_lr() is backward compatible.\n",
    "\n",
    "    Args:\n",
    "        num_warmup_steps (int): steps of linear warmup from 0 to `lr`.\n",
    "        num_training_steps (int): total number of steps to plot.\n",
    "        lr (float): peak learning rate reached at the end of warmup.\n",
    "        num_cycles (float): cosine cycles over the decay phase (0.5 = one half-wave).\n",
    "    \"\"\"\n",
    "    lr_values = []\n",
    "    for current_step in range(num_training_steps):\n",
    "        if current_step < num_warmup_steps:\n",
    "            # Linear warmup: 0 -> lr.\n",
    "            scale = float(current_step) / float(max(1, num_warmup_steps))\n",
    "        else:\n",
    "            # Cosine decay after warmup, clamped at 0 like the real schedule.\n",
    "            progress = float(current_step - num_warmup_steps) / float(\n",
    "                max(1, num_training_steps - num_warmup_steps)\n",
    "            )\n",
    "            scale = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n",
    "        lr_values.append(scale * lr)\n",
    "\n",
    "    plt.plot(lr_values)\n",
    "    plt.title(f'Trend of Learning Rate\\nnum_warmup_steps={num_warmup_steps}\\nnum_training_steps={num_training_steps}')\n",
    "    plt.show()\n",
    "\n",
    "plot_lr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:20:36.495194Z",
     "iopub.status.busy": "2022-12-18T14:20:36.494842Z",
     "iopub.status.idle": "2022-12-18T14:20:36.505978Z",
     "shell.execute_reply": "2022-12-18T14:20:36.504505Z",
     "shell.execute_reply.started": "2022-12-18T14:20:36.495157Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_cosine_schedule_with_warmup(\n",
    "    opt: Optimizer,\n",
    "    num_warmup_steps: int,\n",
    "    num_training_steps: int,\n",
    "    num_cycles: float = 0.5,\n",
    "    last_epoch: int = -1\n",
    "):\n",
    "    \"\"\"Build a LambdaLR schedule: linear warmup then cosine decay.\n",
    "\n",
    "    The learning rate grows linearly from 0 to the optimizer's base LR\n",
    "    over the warmup phase, then follows a cosine curve down to 0 by\n",
    "    `num_training_steps`.\n",
    "\n",
    "    Args:\n",
    "        opt (Optimizer): the wrapped optimizer.\n",
    "        num_warmup_steps (int): number of linear warmup steps.\n",
    "        num_training_steps (int): total number of training steps.\n",
    "        num_cycles (float, optional): cosine cycles over the decay phase\n",
    "            (0.5 = a single half-wave from max down to 0). Defaults to 0.5.\n",
    "        last_epoch (int, optional): index of the last step when resuming\n",
    "            training. Defaults to -1.\n",
    "    \"\"\"\n",
    "    # The decay denominator is fixed at creation time, so hoist it.\n",
    "    decay_steps = max(1, num_training_steps - num_warmup_steps)\n",
    "\n",
    "    def lr_lambda(current_step):\n",
    "        # Warmup phase: scale rises linearly from 0 to 1.\n",
    "        if current_step < num_warmup_steps:\n",
    "            return float(current_step) / float(max(1, num_warmup_steps))\n",
    "        # Decay phase: cosine curve, clamped so it never goes negative.\n",
    "        progress = float(current_step - num_warmup_steps) / float(decay_steps)\n",
    "        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)\n",
    "        return max(0.0, 0.5 * (1.0 + cosine))\n",
    "\n",
    "    return LambdaLR(opt, lr_lambda, last_epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#   &#x2728; **训练部分**\n",
    "这部分和HW01 & HW03基本相同"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:21:59.858004Z",
     "iopub.status.busy": "2022-12-18T14:21:59.857664Z",
     "iopub.status.idle": "2022-12-18T14:21:59.874252Z",
     "shell.execute_reply": "2022-12-18T14:21:59.873181Z",
     "shell.execute_reply.started": "2022-12-18T14:21:59.857977Z"
    }
   },
   "outputs": [],
   "source": [
    "def trainer(train_loader, valid_loader, model, config, device, rest_net_flag=False):\n",
    "    \"\"\"Train `model` with early stopping and an optional cosine-warmup LR schedule.\n",
    "\n",
    "    Saves the checkpoint with the lowest validation loss to config['save_path']\n",
    "    and logs per-epoch loss/accuracy to TensorBoard.\n",
    "\n",
    "    Args:\n",
    "        train_loader / valid_loader: DataLoaders yielding (x, y) batches.\n",
    "        model: the network to train.\n",
    "        config (dict): expects 'learning_rate', 'scheduler_flag', 'warmup_steps',\n",
    "            'n_epochs', 'save_path', 'early_stop'.\n",
    "        device: target device string, e.g. 'cuda'.\n",
    "        rest_net_flag: unused; kept only for backward compatibility with callers.\n",
    "    \"\"\"\n",
    "    # Cross-entropy is the standard criterion for multi-class classification.\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])\n",
    "    if config['scheduler_flag']:\n",
    "        scheduler = get_cosine_schedule_with_warmup(\n",
    "            optimizer, config['warmup_steps'], len(train_loader) * config['n_epochs'])\n",
    "    save_path = config['save_path']\n",
    "\n",
    "    writer = SummaryWriter()\n",
    "    if not os.path.isdir('./models'):\n",
    "        os.mkdir('./models')\n",
    "\n",
    "    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0\n",
    "    for epoch in range(n_epochs):\n",
    "        # ---- training phase ----\n",
    "        model.train()\n",
    "        loss_record = []\n",
    "        train_accs = []\n",
    "        train_pbar = tqdm(train_loader, position=0, leave=True)\n",
    "\n",
    "        for x, y in train_pbar:\n",
    "            optimizer.zero_grad()\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            pred = model(x)\n",
    "            loss = criterion(pred, y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if config['scheduler_flag']:\n",
    "                scheduler.step()  # the schedule is per-step, not per-epoch\n",
    "            step += 1\n",
    "            # y is already on `device`; the original redundant y.to(device) is dropped.\n",
    "            acc = (pred.argmax(dim=-1) == y).float().mean()\n",
    "            l_ = loss.detach().item()\n",
    "            loss_record.append(l_)\n",
    "            train_accs.append(acc.detach().item())\n",
    "            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')\n",
    "            train_pbar.set_postfix({'loss': f'{l_:.5f}', 'acc': f'{acc:.5f}'})\n",
    "\n",
    "        mean_train_acc = sum(train_accs) / len(train_accs)\n",
    "        mean_train_loss = sum(loss_record) / len(loss_record)\n",
    "        writer.add_scalar('Loss/train', mean_train_loss, step)\n",
    "        writer.add_scalar('ACC/train', mean_train_acc, step)\n",
    "\n",
    "        # ---- validation phase ----\n",
    "        model.eval()\n",
    "        loss_record = []\n",
    "        test_accs = []\n",
    "        for x, y in valid_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            with torch.no_grad():\n",
    "                pred = model(x)\n",
    "                loss = criterion(pred, y)\n",
    "                acc = (pred.argmax(dim=-1) == y).float().mean()\n",
    "\n",
    "            loss_record.append(loss.item())\n",
    "            test_accs.append(acc.detach().item())\n",
    "\n",
    "        mean_valid_acc = sum(test_accs) / len(test_accs)\n",
    "        mean_valid_loss = sum(loss_record) / len(loss_record)\n",
    "        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f},acc: {mean_train_acc:.4f} Valid loss: {mean_valid_loss:.4f},acc: {mean_valid_acc:.4f} ')\n",
    "        writer.add_scalar('Loss/valid', mean_valid_loss, step)\n",
    "        writer.add_scalar('ACC/valid', mean_valid_acc, step)\n",
    "\n",
    "        if mean_valid_loss < best_loss:\n",
    "            best_loss = mean_valid_loss\n",
    "            torch.save(model.state_dict(), save_path)  # keep only the best checkpoint\n",
    "            print('Saving model with loss {:.3f}...'.format(best_loss))\n",
    "            early_stop_count = 0\n",
    "        else:\n",
    "            early_stop_count += 1\n",
    "\n",
    "        if early_stop_count >= config['early_stop']:\n",
    "            print('\\nModel is not improving, so we halt the training session.')\n",
    "            writer.close()  # flush TensorBoard events before bailing out\n",
    "            return\n",
    "    # Original version leaked the SummaryWriter's event file; close it here.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **参数设置**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:22:01.989539Z",
     "iopub.status.busy": "2022-12-18T14:22:01.989181Z",
     "iopub.status.idle": "2022-12-18T14:22:01.997891Z",
     "shell.execute_reply": "2022-12-18T14:22:01.996914Z",
     "shell.execute_reply.started": "2022-12-18T14:22:01.989501Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Set env random_seed = 87\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when available, otherwise fall back to CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Run configuration consumed by the data loaders and trainer().\n",
    "config = {\n",
    "    'seed': 87,\n",
    "    'dataset_dir': \"../input/ml2022spring-hw4/Dataset\",\n",
    "    'n_epochs': 35,      \n",
    "    'batch_size': 64, \n",
    "    \n",
    "    # LR schedule: linear warmup then cosine decay (get_cosine_schedule_with_warmup).\n",
    "    'scheduler_flag': True,\n",
    "    'valid_steps': 2000,\n",
    "    'warmup_steps': 1000,\n",
    "    # 'total_steps': 70000, # len(train) * n_epochs\n",
    "    'learning_rate': 1e-3,          \n",
    "    # Stop after this many consecutive epochs without validation-loss improvement.\n",
    "    'early_stop': 300,\n",
    "    'n_workers': 8,\n",
    "    'save_path': './models/model.ckpt'\n",
    "}\n",
    "print(device)\n",
    "# all_seed (defined earlier in the notebook) seeds the RNGs for reproducibility.\n",
    "all_seed(config['seed'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **导入数据集**\n",
    "- 将数据集分割成训练集(90%)和验证集(10%).\n",
    "- 创建dataloader用于模型训练.\n",
    "- 用`pad_sequence`方法将一个batch中的数据都扩展成一样的长度(`collate_batch`)  \n",
    "\n",
    "    Example:\n",
    "```python\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "a = torch.ones(25, 40)\n",
    "b = torch.ones(22, 40)\n",
    "c = torch.ones(15, 40)\n",
    "pad_sequence([a, b, c], batch_first=True).size() # 都扩展成一样长\n",
    "torch.Size([3, 25, 40])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:22:06.056552Z",
     "iopub.status.busy": "2022-12-18T14:22:06.055630Z",
     "iopub.status.idle": "2022-12-18T14:22:06.289377Z",
     "shell.execute_reply": "2022-12-18T14:22:06.288297Z",
     "shell.execute_reply.started": "2022-12-18T14:22:06.056493Z"
    }
   },
   "outputs": [],
   "source": [
    "def collate_batch(batch):\n",
    "    \"\"\"Collate a batch of (mel, speaker) pairs into padded tensors.\n",
    "\n",
    "    Pads each mel spectrogram to the length of the longest one in the\n",
    "    batch so they can be stacked, with the batch dimension first.\n",
    "\n",
    "    Returns:\n",
    "        mel: (batch_size, max_length, 40) float tensor.\n",
    "        speaker: (batch_size,) long tensor of speaker ids.\n",
    "    \"\"\"\n",
    "    mel, speaker = zip(*batch)\n",
    "    # Pad with -20, i.e. log 10^(-20): an effectively silent log-mel value.\n",
    "    mel = pad_sequence(mel, batch_first=True, padding_value=-20)\n",
    "    # Build integer labels directly; the original FloatTensor(...).long()\n",
    "    # round-trip through float32 would corrupt ids larger than 2**24.\n",
    "    return mel, torch.LongTensor(speaker)\n",
    "\n",
    "\n",
    "data_dir = config['dataset_dir']\n",
    "dataset = myDataset(data_dir)\n",
    "speaker_num = dataset.get_speaker_number()\n",
    "speaker2id = dataset.speaker2id\n",
    "# Split the dataset into 90% train / 10% validation.\n",
    "trainlen = int(0.9 * len(dataset))\n",
    "lengths = [trainlen, len(dataset) - trainlen]\n",
    "trainset, validset = random_split(dataset, lengths)\n",
    "testset = InferenceDataset(data_dir)\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    trainset,\n",
    "    batch_size=config['batch_size'],\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=config['n_workers'],\n",
    "    pin_memory=True,\n",
    "    collate_fn=collate_batch,\n",
    ")\n",
    "\n",
    "valid_loader = DataLoader(\n",
    "    validset,\n",
    "    batch_size=config['batch_size'],\n",
    "    num_workers=config['n_workers'],\n",
    "    drop_last=True,\n",
    "    pin_memory=True,\n",
    "    collate_fn=collate_batch,\n",
    ")\n",
    "\n",
    "# The test set is evaluated one utterance at a time, in order.\n",
    "test_loader = DataLoader(\n",
    "    testset,\n",
    "    batch_size=1,\n",
    "    num_workers=config['n_workers'],\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    "    pin_memory=True,\n",
    "    collate_fn=inference_collate_batch,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-17T01:21:04.154130Z",
     "iopub.status.busy": "2022-12-17T01:21:04.153501Z",
     "iopub.status.idle": "2022-12-17T01:21:04.158714Z",
     "shell.execute_reply": "2022-12-17T01:21:04.157635Z",
     "shell.execute_reply.started": "2022-12-17T01:21:04.154079Z"
    }
   },
   "source": [
    "#  &#x1F4CC; **开始训练！**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:22:11.082139Z",
     "iopub.status.busy": "2022-12-18T14:22:11.081533Z",
     "iopub.status.idle": "2022-12-18T14:40:52.298707Z",
     "shell.execute_reply": "2022-12-18T14:40:52.295153Z",
     "shell.execute_reply.started": "2022-12-18T14:22:11.082104Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [1/35]: 100%|██████████| 796/796 [00:56<00:00, 14.21it/s, loss=4.99485, acc=0.09375]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/35]: Train loss: 5.9355,acc: 0.0156 Valid loss: 4.9507,acc: 0.0595 \n",
      "Saving model with loss 4.951...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [2/35]: 100%|██████████| 796/796 [00:27<00:00, 29.22it/s, loss=3.60158, acc=0.21875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [2/35]: Train loss: 4.4725,acc: 0.1081 Valid loss: 3.9621,acc: 0.1795 \n",
      "Saving model with loss 3.962...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [3/35]: 100%|██████████| 796/796 [00:28<00:00, 28.36it/s, loss=3.16526, acc=0.28125]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [3/35]: Train loss: 3.7922,acc: 0.1974 Valid loss: 3.4876,acc: 0.2493 \n",
      "Saving model with loss 3.488...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [4/35]: 100%|██████████| 796/796 [00:28<00:00, 27.79it/s, loss=3.33363, acc=0.29688]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [4/35]: Train loss: 3.4181,acc: 0.2598 Valid loss: 3.2367,acc: 0.2926 \n",
      "Saving model with loss 3.237...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [5/35]: 100%|██████████| 796/796 [00:28<00:00, 28.13it/s, loss=3.20587, acc=0.31250]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/35]: Train loss: 3.1651,acc: 0.3075 Valid loss: 3.0073,acc: 0.3425 \n",
      "Saving model with loss 3.007...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [6/35]: 100%|██████████| 796/796 [00:29<00:00, 26.65it/s, loss=2.82753, acc=0.28125]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [6/35]: Train loss: 2.9686,acc: 0.3386 Valid loss: 2.8559,acc: 0.3695 \n",
      "Saving model with loss 2.856...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [7/35]: 100%|██████████| 796/796 [00:30<00:00, 26.47it/s, loss=3.00548, acc=0.32812]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [7/35]: Train loss: 2.8109,acc: 0.3689 Valid loss: 2.6610,acc: 0.4125 \n",
      "Saving model with loss 2.661...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [8/35]: 100%|██████████| 796/796 [00:29<00:00, 26.75it/s, loss=2.59602, acc=0.40625]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [8/35]: Train loss: 2.6778,acc: 0.3940 Valid loss: 2.6267,acc: 0.4190 \n",
      "Saving model with loss 2.627...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [9/35]: 100%|██████████| 796/796 [00:29<00:00, 27.35it/s, loss=2.58245, acc=0.37500]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [9/35]: Train loss: 2.5649,acc: 0.4128 Valid loss: 2.4476,acc: 0.4528 \n",
      "Saving model with loss 2.448...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [10/35]: 100%|██████████| 796/796 [00:28<00:00, 28.17it/s, loss=2.12685, acc=0.45312]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [10/35]: Train loss: 2.4754,acc: 0.4308 Valid loss: 2.3831,acc: 0.4732 \n",
      "Saving model with loss 2.383...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [11/35]: 100%|██████████| 796/796 [00:27<00:00, 28.53it/s, loss=2.19862, acc=0.53125]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [11/35]: Train loss: 2.3947,acc: 0.4504 Valid loss: 2.2878,acc: 0.4892 \n",
      "Saving model with loss 2.288...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [12/35]: 100%|██████████| 796/796 [00:28<00:00, 28.19it/s, loss=2.14563, acc=0.46875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [12/35]: Train loss: 2.3208,acc: 0.4597 Valid loss: 2.3081,acc: 0.4828 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [13/35]: 100%|██████████| 796/796 [00:28<00:00, 27.87it/s, loss=2.26823, acc=0.39062]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [13/35]: Train loss: 2.2543,acc: 0.4759 Valid loss: 2.1940,acc: 0.5163 \n",
      "Saving model with loss 2.194...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [14/35]: 100%|██████████| 796/796 [00:28<00:00, 28.27it/s, loss=2.28814, acc=0.45312]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [14/35]: Train loss: 2.1906,acc: 0.4887 Valid loss: 2.1073,acc: 0.5268 \n",
      "Saving model with loss 2.107...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [15/35]: 100%|██████████| 796/796 [00:27<00:00, 28.64it/s, loss=2.36237, acc=0.48438]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [15/35]: Train loss: 2.1317,acc: 0.5014 Valid loss: 2.0983,acc: 0.5256 \n",
      "Saving model with loss 2.098...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [16/35]: 100%|██████████| 796/796 [00:27<00:00, 28.74it/s, loss=1.77035, acc=0.57812]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [16/35]: Train loss: 2.0813,acc: 0.5101 Valid loss: 2.0207,acc: 0.5478 \n",
      "Saving model with loss 2.021...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [17/35]: 100%|██████████| 796/796 [00:27<00:00, 28.58it/s, loss=2.45166, acc=0.56250]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [17/35]: Train loss: 2.0341,acc: 0.5211 Valid loss: 2.0007,acc: 0.5471 \n",
      "Saving model with loss 2.001...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [18/35]: 100%|██████████| 796/796 [00:28<00:00, 27.98it/s, loss=2.22434, acc=0.48438]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [18/35]: Train loss: 2.0002,acc: 0.5276 Valid loss: 1.9689,acc: 0.5518 \n",
      "Saving model with loss 1.969...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [19/35]: 100%|██████████| 796/796 [00:28<00:00, 28.41it/s, loss=2.04935, acc=0.46875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [19/35]: Train loss: 1.9536,acc: 0.5394 Valid loss: 1.9298,acc: 0.5655 \n",
      "Saving model with loss 1.930...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [20/35]: 100%|██████████| 796/796 [00:27<00:00, 28.48it/s, loss=1.83221, acc=0.56250]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [20/35]: Train loss: 1.9053,acc: 0.5431 Valid loss: 1.8723,acc: 0.5705 \n",
      "Saving model with loss 1.872...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [21/35]: 100%|██████████| 796/796 [00:27<00:00, 28.91it/s, loss=1.97107, acc=0.57812]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [21/35]: Train loss: 1.8743,acc: 0.5534 Valid loss: 1.8888,acc: 0.5716 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [22/35]: 100%|██████████| 796/796 [00:27<00:00, 28.66it/s, loss=1.53612, acc=0.67188]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [22/35]: Train loss: 1.8438,acc: 0.5607 Valid loss: 1.8568,acc: 0.5856 \n",
      "Saving model with loss 1.857...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [23/35]: 100%|██████████| 796/796 [00:27<00:00, 28.50it/s, loss=1.65329, acc=0.60938]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [23/35]: Train loss: 1.8093,acc: 0.5662 Valid loss: 1.8053,acc: 0.5939 \n",
      "Saving model with loss 1.805...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [24/35]: 100%|██████████| 796/796 [00:28<00:00, 28.29it/s, loss=1.99409, acc=0.48438]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [24/35]: Train loss: 1.7781,acc: 0.5748 Valid loss: 1.7939,acc: 0.5943 \n",
      "Saving model with loss 1.794...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [25/35]: 100%|██████████| 796/796 [00:27<00:00, 28.71it/s, loss=1.45089, acc=0.67188]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [25/35]: Train loss: 1.7634,acc: 0.5774 Valid loss: 1.7809,acc: 0.6044 \n",
      "Saving model with loss 1.781...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [26/35]: 100%|██████████| 796/796 [00:27<00:00, 28.74it/s, loss=1.96837, acc=0.56250]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [26/35]: Train loss: 1.7360,acc: 0.5804 Valid loss: 1.7456,acc: 0.6010 \n",
      "Saving model with loss 1.746...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [27/35]: 100%|██████████| 796/796 [00:28<00:00, 28.40it/s, loss=1.41114, acc=0.71875]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [27/35]: Train loss: 1.7184,acc: 0.5856 Valid loss: 1.7315,acc: 0.6083 \n",
      "Saving model with loss 1.731...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [28/35]: 100%|██████████| 796/796 [00:28<00:00, 28.40it/s, loss=2.00252, acc=0.50000]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [28/35]: Train loss: 1.6994,acc: 0.5923 Valid loss: 1.7210,acc: 0.6106 \n",
      "Saving model with loss 1.721...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [29/35]: 100%|██████████| 796/796 [00:28<00:00, 28.17it/s, loss=2.06908, acc=0.45312]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [29/35]: Train loss: 1.6816,acc: 0.5956 Valid loss: 1.7263,acc: 0.6069 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [30/35]: 100%|██████████| 796/796 [00:27<00:00, 29.30it/s, loss=1.20732, acc=0.67188]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [30/35]: Train loss: 1.6613,acc: 0.5998 Valid loss: 1.6684,acc: 0.6072 \n",
      "Saving model with loss 1.668...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [31/35]: 100%|██████████| 796/796 [00:27<00:00, 29.45it/s, loss=1.84132, acc=0.59375]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [31/35]: Train loss: 1.6512,acc: 0.6004 Valid loss: 1.6914,acc: 0.6199 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [32/35]: 100%|██████████| 796/796 [00:27<00:00, 29.24it/s, loss=1.58765, acc=0.62500]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [32/35]: Train loss: 1.6466,acc: 0.6020 Valid loss: 1.7334,acc: 0.6051 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [33/35]: 100%|██████████| 796/796 [00:27<00:00, 28.56it/s, loss=1.75043, acc=0.62500]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [33/35]: Train loss: 1.6376,acc: 0.6050 Valid loss: 1.7110,acc: 0.6175 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [34/35]: 100%|██████████| 796/796 [00:27<00:00, 29.09it/s, loss=1.57912, acc=0.65625]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [34/35]: Train loss: 1.6399,acc: 0.6036 Valid loss: 1.6946,acc: 0.6211 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch [35/35]: 100%|██████████| 796/796 [00:27<00:00, 28.73it/s, loss=1.64054, acc=0.62500]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [35/35]: Train loss: 1.6312,acc: 0.6090 Valid loss: 1.7201,acc: 0.6143 \n"
     ]
    }
   ],
   "source": [
     "# Build the transformer-based speaker classifier and launch training.\n",
     "model = Classifier(\n",
     "    input_dim=40,  # n_mel: mel-spectrogram feature dimension per frame\n",
     "    d_model=80,    # transformer embedding width\n",
     "    n_spks=600,    # output classes -- presumably the number of speakers; confirm against mapping.json\n",
     "    dropout=0.1\n",
     ").to(device)\n",
     "# trainer runs the full train/validate loop and checkpoints the best model (see logs above).\n",
     "trainer(train_loader, valid_loader, model, config, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **测试并生成预测结果的csv**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:40:52.301710Z",
     "iopub.status.busy": "2022-12-18T14:40:52.301296Z",
     "iopub.status.idle": "2022-12-18T14:41:17.215394Z",
     "shell.execute_reply": "2022-12-18T14:41:17.214208Z",
     "shell.execute_reply.started": "2022-12-18T14:40:52.301669Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8000/8000 [00:24<00:00, 321.50it/s]\n"
     ]
    }
   ],
   "source": [
    "model_best = Classifier().to(device)\n",
    "model_best.load_state_dict(torch.load(config['save_path']))\n",
    "model_best.eval()\n",
    "mapping_path = Path(data_dir) / \"mapping.json\"\n",
    "mapping = json.load(mapping_path.open())\n",
    "pred_id = []\n",
    "pred_final_cls = []\n",
    "with torch.no_grad():\n",
    "    for name, data in tqdm(test_loader):\n",
    "        test_pred = model_best(data.to(device))\n",
    "        test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)\n",
    "        pred_id += name\n",
    "        pred_final_cls += [mapping[\"id2speaker\"][str(test_label[0])]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-18T14:41:17.217459Z",
     "iopub.status.busy": "2022-12-18T14:41:17.217102Z",
     "iopub.status.idle": "2022-12-18T14:41:17.255614Z",
     "shell.execute_reply": "2022-12-18T14:41:17.254748Z",
     "shell.execute_reply.started": "2022-12-18T14:41:17.217429Z"
    }
   },
   "outputs": [],
   "source": [
    "df = pd.DataFrame()\n",
    "df[\"Id\"] = pred_id\n",
    "df[\"Category\"] = pred_final_cls\n",
    "df.to_csv(\"submission.csv\",index = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 贡献者\n",
    "\n",
    "孙成超\n",
    "- Github: https://github.com/scchy\n",
    "- Email: hyscc1994@foxmail.com"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
